from dataclasses import asdict, dataclass, fields
from pathlib import Path
from typing import Optional, TypedDict, Union
from typing_extensions import Self

from conduit.data import TernarySample
from conduit.data.datasets import CdtDataLoader
from conduit.data.datasets.vision import ImageTform
from loguru import logger
import numpy as np
import numpy.typing as npt
from ranzen import gcopy
import torch
from torch import Tensor, nn
from torch.nn import functional as F
from tqdm import tqdm

from src.data import DataModule, labels_to_group_id

__all__ = ["Encodings", "generate_encodings"]


@dataclass
class Encodings:
    """Result of encoding the data.

    ``train``, ``train_labels`` and ``dep`` are held as tensors, whereas ``test`` and
    ``test_labels`` are held as numpy arrays.
    """

    train: Tensor  # This is used to compute the pre-defined starting points.
    train_labels: Tensor  # Same as above.
    dep: Tensor  # This is used to compute the other starting points.
    test: npt.NDArray  # This is used for evaluation.
    test_labels: npt.NDArray[np.int32]  # Same as above.

    def normalize_(self, p: float = 2) -> None:
        """Replace ``train``, ``dep`` and ``test`` with versions normalized along dimension 1.

        The label fields are left untouched.

        :param p: Exponent of the norm, passed on to ``F.normalize``.
        """
        self.train = F.normalize(self.train, dim=1, p=p)
        self.dep = F.normalize(self.dep, dim=1, p=p)
        self.test = F.normalize(torch.as_tensor(self.test), dim=1, p=p).numpy()

    def save(self, fpath: Union[Path, str]) -> None:
        """Write all fields to a compressed ``.npz`` archive, one entry per field name.

        :param fpath: Destination path. As with ``np.savez_compressed``, ``.npz`` is appended
            when the path does not already end with it.
        """
        fpath = Path(fpath)
        # numpy appends ".npz" on its own when it is missing; apply the same rule up front so
        # that the logged path is the file that actually gets written.
        if not str(fpath).endswith(".npz"):
            fpath = Path(f"{fpath}.npz")
        logger.info(f"Saving encodings to '{fpath.resolve()}'")
        data = {}
        # Walk the fields directly: ``asdict`` would deep-copy every tensor/array first.
        for field in fields(self):
            value = getattr(self, field.name)
            if not isinstance(value, np.ndarray):
                # ``.numpy()`` alone fails for tensors that require grad or are not on the CPU.
                value = value.detach().cpu().numpy()
            data[field.name] = value
        np.savez_compressed(fpath, **data)
        logger.info("Done.")

    @property
    def to_cluster(self) -> npt.NDArray:
        """Train and deployment encodings concatenated along dimension 0, as a numpy array."""
        return torch.cat([self.train, self.dep], dim=0).detach().cpu().numpy()

    @classmethod
    def from_npz(cls, fpath: Union[Path, str]) -> Self:
        """Load encodings from an ``.npz`` archive with the entries written by :meth:`save`.

        :param fpath: Path of the archive to read.
        :returns: A new instance; ``train``, ``train_labels`` and ``dep`` are converted to
            tensors, ``test`` and ``test_labels`` stay numpy arrays.
        """
        logger.info("Loading encodings from file...")
        # The archive is read lazily, so every array is extracted while it is still open; the
        # context manager then closes the underlying file handle.
        with np.load(Path(fpath)) as archive:
            loaded: NpzContent = archive
            return cls(
                train=torch.from_numpy(loaded["train"]),
                train_labels=torch.from_numpy(loaded["train_labels"]),
                dep=torch.from_numpy(loaded["dep"]),
                test=loaded["test"],
                test_labels=loaded["test_labels"],
            )


@torch.no_grad()
def generate_encodings(
    dm: DataModule,
    *,
    encoder: nn.Module,
    device: Union[str, torch.device],
    batch_size_tr: Optional[int] = None,
    batch_size_te: Optional[int] = None,
    transforms: Optional[ImageTform] = None,
    save_path: Union[Path, str, None] = None,
) -> Encodings:
    """Encode the train, deployment and test splits of ``dm`` with a pre-trained model.

    :param dm: Data module providing the three dataloaders. A shallow copy is used internally.
    :param encoder: Model that maps a batch of inputs to encodings; it is moved to ``device``.
    :param device: Device on which the encoder is run.
    :param batch_size_tr: Batch size passed to the train dataloader.
    :param batch_size_te: Batch size passed to the deployment dataloader.
    :param transforms: If given, set as the transform for all splits of the copied data module.
    :param save_path: If given, the encodings are additionally saved to this path.
    :returns: The encodings of all three splits, with group ids for the train and test splits.
    """
    # NOTE(review): the encoder's train/eval mode is not changed in this function; presumably
    # callers switch it to eval mode beforehand -- confirm.
    # Shallow copy, presumably so that replacing the transforms does not leak to the caller.
    dm = gcopy(dm, deep=False)
    if transforms is not None:
        dm.set_transforms_all(transforms)

    encoder.to(device)

    # Each dataloader is created right before it is consumed.
    train_dl = dm.train_dataloader(eval=True, batch_size=batch_size_tr)
    tr_encoded, tr_groups = encode_with_group_ids(encoder, dl=train_dl, device=device)

    dep_dl = dm.deployment_dataloader(eval=True, batch_size=batch_size_te)
    dep_encoded, _ = encode_with_group_ids(encoder, dl=dep_dl, device=device)

    # The test dataloader is requested with the data module's own settings.
    test_dl = dm.test_dataloader()
    te_encoded, te_groups = encode_with_group_ids(encoder, dl=test_dl, device=device)

    torch.cuda.empty_cache()

    # Only the test split is stored as numpy arrays.
    result = Encodings(
        train=tr_encoded,
        train_labels=tr_groups,
        dep=dep_encoded,
        test=te_encoded.numpy(),
        test_labels=te_groups.numpy(),
    )
    if save_path is None:
        return result
    result.save(save_path)
    return result


@torch.no_grad()
def encode_with_group_ids(
    model: nn.Module,
    *,
    dl: CdtDataLoader[TernarySample[Tensor]],
    device: Union[str, torch.device],
    s_count: int = 2,
) -> tuple[Tensor, Tensor]:
    """Run every batch of ``dl`` through ``model`` and collect encodings and group ids.

    :param model: Model that maps a batch of inputs to encodings; it is moved to ``device``.
    :param dl: Dataloader yielding samples with ``x``, ``s`` and ``y`` attributes.
    :param device: Device on which the model is run.
    :param s_count: Forwarded to ``labels_to_group_id`` (presumably the number of distinct
        values of ``s`` -- confirm against that helper). Defaults to 2, the value that was
        previously hard-coded.
    :returns: A tuple of the encodings (moved to the CPU) and the group ids, each concatenated
        along dimension 0 over all batches.
    """
    model.to(device)
    encoded: list[Tensor] = []
    group_ids: list[Tensor] = []
    # Autograd is already disabled for the whole call by the decorator, so no nested
    # ``torch.no_grad()`` block is needed here.
    for sample in tqdm(dl, total=len(dl), desc="Encoding dataset"):
        enc = model(sample.x.to(device, non_blocking=True)).detach()
        # Move each batch off the device right away so that only one batch lives there.
        encoded.append(enc.cpu())
        group_ids.append(labels_to_group_id(s=sample.s, y=sample.y, s_count=s_count))
    logger.info("Done.")
    return torch.cat(encoded, dim=0), torch.cat(group_ids, dim=0)


class NpzContent(TypedDict):
    """Content of the npz file (which is basically a dictionary).

    The keys are the ones read by ``Encodings.from_npz`` and correspond to the field names of
    ``Encodings``; all values are numpy arrays.
    """

    # Converted to a tensor when loaded via ``Encodings.from_npz``.
    train: npt.NDArray
    # Converted to a tensor when loaded via ``Encodings.from_npz``.
    train_labels: npt.NDArray
    # Converted to a tensor when loaded via ``Encodings.from_npz``.
    dep: npt.NDArray
    # Kept as a numpy array when loaded.
    test: npt.NDArray
    # Kept as a numpy array when loaded.
    test_labels: npt.NDArray[np.int32]
